library(tidyverse)
library(maptools)
library(igraph)
library(parallel)
library(redist)
library(spdep)
library(gridExtra)
gpclibPermit()

## Importance-resample MCMC draws taken under a population constraint.
##
## Keeps the draws that (a) were taken at tempering value `beta` and (b) lie
## within `pop` of population parity, then resamples them with replacement
## (same number of draws) using weights 1 / exp(beta * psi), where psi is the
## stored population-constraint value of each draw.
## NOTE(review): these weights assume the chain targets a density
## proportional to exp(beta * psi); confirm against the redist version used.
##
## x:    redist.mcmc() output; the fields read are distance_parity,
##       beta_sequence, constraint_pop and partitions.
## beta: tempering value whose draws are kept (exact match, so it must equal
##       the stored value bit-for-bit).
## pop:  maximum distance from parity to keep.
## Returns the resampled columns of x$partitions as a matrix, even when only
## one draw survives. Errors if no draw satisfies both filters.
ipw <- function(x, beta, pop){
    keep <- which(x$distance_parity <= pop & x$beta_sequence == beta)
    if (length(keep) == 0) {
        stop("ipw(): no draws with beta_sequence == ", beta,
             " and distance_parity <= ", pop, call. = FALSE)
    }
    psi <- x$constraint_pop[keep]
    ## Work on the log scale and shift by the maximum so exp() cannot
    ## overflow; sample.int() normalises `prob`, so the shift is harmless.
    log_w <- -beta * psi
    w <- exp(log_w - max(log_w))
    ## Resample positions with sample.int(): sample(keep, ...) would draw
    ## from 1:keep when exactly one draw survives.
    draws <- keep[sample.int(length(keep), length(keep),
                             replace = TRUE, prob = w)]
    ## drop = FALSE keeps a matrix even for a single column.
    x$partitions[, draws, drop = FALSE]
}
## House ggplot2 theme, built on theme_classic(). It blanks the panel
## background, grid and strip background, draws black axis lines and a black
## panel border, centres the title and subtitle, and puts an untitled legend
## at the bottom.
ben_theme <- function() {
    overrides <- theme(
        panel.background = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        panel.border = element_rect(colour = "black", fill = NA, size = 1),
        strip.background = element_blank(),
        legend.position = "bottom",
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5),
        plot.subtitle = element_text(hjust = .5)
    )
    theme_classic() + overrides
}

fl_map <- readShapePoly("../../data/fl/FL.shp")

## Load nodes: the first line of map_sub_nodes.dat, split on single spaces,
## gives the row names of the precincts that make up the sub-map
nodes <- strsplit(readLines("../../data/largemap_enum_sample/map_sub_nodes.dat")," ")[[1]]
## nodes <- strsplit(readLines("../../data/big_validation_map/fl_sub_nodes.dat")," ")[[1]]

## First, subset down to the precincts listed in `nodes`
fl_sub <- fl_map[which(rownames(fl_map@data) %in% nodes),]

## Then reorder the rows to follow the order of `nodes` - otherwise they
## won't line up properly with the original map
fl_sub <- fl_sub[match(nodes, rownames(fl_sub@data)),]

## Convert factor columns to numeric, because this particular data is messy.
## Go through as.character() so the factor labels are parsed rather than the
## internal level codes that as.numeric() on a factor would return.
## NOTE(review): a factor column whose labels are not numbers becomes NA
## (with a coercion warning).
## vapply() rather than sapply(): on a zero-column table sapply() returns
## list(), which is not a valid column index.
indx <- vapply(fl_sub@data, is.factor, logical(1))
fl_sub@data[indx] <- lapply(
    fl_sub@data[indx], function(x) as.numeric(as.character(x))
)

## --------------
## Load solutions
## --------------
## Read dissimilarity_1.csv ... dissimilarity_10.csv in parallel (10 workers)
## and stack them into a single table of solutions.
sols <- do.call(
    "rbind",
    mclapply(
        paste0("../../data/largemap_enum_sample/dissimilarity_", 1:10, ".csv"),
        read_csv,
        mc.cores = 10
    )
)

## ---------------
## Run simulations
## ---------------
pop <- as.numeric(fl_sub@data$pop)
## Group population for the segregation index (the `mccain` column). Named
## rep_votes rather than `rep` so it does not shadow base::rep().
rep_votes <- as.numeric(fl_sub@data$mccain)
nsims <- 500000
nchains <- 8

## Run `nchains` independent redist.mcmc() chains in parallel. Returns the
## pooled redist.segcalc() values of every retained draw as one numeric
## vector.
##
## beta:   NULL for an unconstrained run. Otherwise, the population-constraint
##         strength passed to redist.mcmc(); the draws are then filtered and
##         importance-resampled by ipw().
## parity: maximum distance from population parity passed to ipw(); ignored
##         when beta is NULL.
run_seg_chains <- function(beta = NULL, parity = NULL) {
    one_chain <- function(chain) {
        if (is.null(beta)) {
            out <- redist.mcmc(fl_sub, popvec = pop,
                               nsims = nsims, ndists = 2)
        } else {
            out <- redist.mcmc(fl_sub, popvec = pop,
                               nsims = nsims, ndists = 2,
                               constraint = "population", beta = beta)
            out <- ipw(out, beta = beta, pop = parity)
        }
        coda::mcmc(redist.segcalc(out, grouppop = rep_votes, fullpop = pop))
    }
    ## No point forking more workers than chains; detectCores() can be NA.
    n_cores <- min(nchains, detectCores(), na.rm = TRUE)
    chains <- mclapply(seq_len(nchains), one_chain, mc.cores = n_cores)
    ## mclapply() returns a "try-error" object for a failed worker instead of
    ## raising. unlist() would silently turn those into character strings.
    failed <- vapply(chains, inherits, logical(1), what = "try-error")
    if (any(failed)) {
        stop("MCMC chain(s) ", paste(which(failed), collapse = ", "),
             " failed: ", chains[[which(failed)[1]]], call. = FALSE)
    }
    unlist(chains)
}

## Unconstrained
seg_unc <- run_seg_chains()

## 5% parity
seg_05 <- run_seg_chains(beta = -25, parity = .05)

## 1% parity
seg_01 <- run_seg_chains(beta = -50, parity = .01)

## RSG
## Read every rsg_* output file in the sample directory (20 workers) and
## stack them. list.files() takes a regular expression, not a glob: "rsg_*"
## would match any name merely containing "rsg", so anchor it to the prefix.
fn <- list.files("../../data/largemap_enum_sample/", pattern = "^rsg_")
## With no matches, paste0() below would yield the bare directory path.
if (length(fn) == 0) {
    stop("No rsg_* files found in ../../data/largemap_enum_sample/",
         call. = FALSE)
}
rsg_sims <- do.call("rbind", mclapply(paste0("../../data/largemap_enum_sample/", fn), read_csv, mc.cores = 20))

## Create dataframe for plot
## Long-format data for the validation plot. Each row has a dissimilarity
## value (`seg`), the population-parity constraint it belongs to (`parity`)
## and its source (`alg`: Truth = `sols`, MCMC = the chains above,
## RSG = `rsg_sims`). Built inside local() so helper objects do not leak
## into the global environment.
df_plot <- local({
    parity_levels <- c("No Equal Population Constraint",
                       "5% Equal Population Constraint",
                       "1% Equal Population Constraint")

    ## Truth: the same solutions filtered to each parity threshold
    truth <- bind_rows(
        tibble(seg = sols$seg, parity = parity_levels[1]),
        tibble(seg = sols$seg[sols$par <= .05], parity = parity_levels[2]),
        tibble(seg = sols$seg[sols$par <= .01], parity = parity_levels[3])
    )

    mcmc <- bind_rows(
        tibble(seg = seg_unc, parity = parity_levels[1]),
        tibble(seg = seg_05, parity = parity_levels[2]),
        tibble(seg = seg_01, parity = parity_levels[3])
    )

    ## RSG rows carry their own parity label. Keep only the three plotted
    ## labels (%in% also drops NA labels), so seg, parity and alg are always
    ## built from the same rows.
    rsg <- bind_rows(lapply(parity_levels, function(p) {
        tibble(seg = rsg_sims$seg[rsg_sims$parity %in% p], parity = p)
    }))

    out <- bind_rows(Truth = truth, MCMC = mcmc, RSG = rsg, .id = "alg")
    out[, c("seg", "parity", "alg")] %>%
        mutate(alg = factor(alg, levels = c("Truth", "MCMC", "RSG")),
               parity = factor(parity, levels = parity_levels))
})

## write_csv(df_plot[df_plot$alg != "Truth",], path = "../../data/largemap_enum_sample/mcmc_out.csv")
## mcmc_sims <- read_csv("../../data/largemap_enum_sample/mcmc_out.csv")
## fn <- list.files("../../data/largemap_enum_sample/", pattern="rsg_*")
## nums <- parse_number(fn)
## inds <- which(nums > 200)

## rsg_sims <- do.call("rbind", mclapply(paste0("../../data/largemap_enum_sample/", fn[inds]), read_csv, mc.cores = 20))
## rsg_sims_new <- data.frame(seg = c(rsg_sims$seg[rsg_sims$parity == "No Equal Population Constraint"],
##                                    rsg_sims$seg[rsg_sims$parity == "5% Equal Population Constraint"],
##                                    rsg_sims$seg[rsg_sims$parity == "1% Equal Population Constraint"]),
##                            parity = c(rep("No Equal Population Constraint", sum(rsg_sims$parity == "No Equal Population Constraint")),
##                                       rep("5% Equal Population Constraint", sum(rsg_sims$parity == "5% Equal Population Constraint")),
##                                       rep("1% Equal Population Constraint", sum(rsg_sims$parity == "1% Equal Population Constraint"))),
##                            alg = "RSG")
## truth_sims <- data.frame(seg = c(sols$seg, sols$seg[sols$par <= .05], sols$seg[sols$par <= .01]),
##                          parity = c(rep("No Equal Population Constraint", nrow(sols)),
##                                     rep("5% Equal Population Constraint", nrow(sols[sols$par <= .05,])),
##                                     rep("1% Equal Population Constraint", nrow(sols[sols$par <= .01,]))),
##                          alg = "Truth")

## df_plot <- bind_rows(mcmc_sims, truth_sims, rsg_sims_new) %>%
##     mutate(alg = factor(alg, levels = c("Truth", "MCMC", "RSG")),
##            parity = factor(parity, levels = c("No Equal Population Constraint",
##                                               "5% Equal Population Constraint",
##                                               "1% Equal Population Constraint")))

## Validation figure: density of the Republican dissimilarity index for each
## source (Truth / MCMC / RSG), one panel per population-parity constraint.
seg_density_plot <- ggplot(df_plot, aes(seg, colour = alg, fill = alg,
                                        alpha = alg, lty = alg)) +
    geom_density(lwd = 1.1) +
    facet_grid(~ parity) +
    scale_colour_manual(values = c("grey", "black", "red")) +
    scale_fill_manual(values = c("grey", "white", "white")) +
    scale_linetype_manual(values = c(1, 1, 2)) +
    scale_alpha_manual(values = c(1, 0, 0)) +
    labs(x = "Republican Dissimilarity Index", y = "Density") +
    ## ben_theme() is already a complete theme_classic()-based theme, so a
    ## separate theme_classic() before it would be overwritten anyway
    ben_theme() +
    ## guides(colour = guide_legend(nrow = 1, byrow = TRUE)) +
    theme(legend.position = c(.1, .75))

## Save explicitly. `<ggplot> + ggsave()` only ever worked by accident (ggsave
## saved last_plot()) and errors since ggplot2 3.3.4.
ggsave("../../paper/figs/largemap_enumerate_sample_validation.pdf",
       plot = seg_density_plot, height = 3, width = 8)

## ----------
## Create map
## ----------
## Flatten the polygons into a data frame ggplot can draw. Each row is tagged
## with its row name as `id`, the shapes are fortified by that id, and the
## attribute table is joined back on by `id`.
fl_sub@data$id <- rownames(fl_sub@data)
fl_sub_points <- fortify(fl_sub, region = "id")
fl_sub_df <- plyr::join(fl_sub_points, fl_sub@data, by = "id")

## Map plot: precincts filled on a white-to-black population gradient, with
## black outlines
map_plot <- ggplot(fl_sub_df, aes(x = long, y = lat, group = group, fill = pop)) +
    geom_polygon(size = .3, colour = "black") +
    geom_path(lwd = .05, color = "black") +
    scale_fill_gradient(high = "black", low = "white") +
    guides(fill = guide_legend(title = "Population")) +
    coord_equal() +
    labs(title = "250-Precinct Validation Map",
         x = "Longitude", y = "Latitude") +
    theme_classic() +
    theme(
        panel.border = element_rect(colour = "black", fill = NA),
        plot.title = element_text(hjust = 0.5),
        legend.position = c(.85, .825),
        legend.title = element_text(size = 8),
        legend.text = element_text(size = 8)
    )

## Parity plot
## Share of all sampled plans (denominator nrow(sols)) falling in each
## 0.01-wide bin of distance from parity, for plans with par <= .2. `pg` is
## the upper edge of the bin (bin index / 100).
parity_plot_df <- sols %>% filter(par <= .2) %>%
    mutate(
        ## Breaks run up to the first multiple of .01 at or above max(par).
        ## seq(0, max(par), .01) stops short of max(par) whenever it is not a
        ## multiple of .01, and cut() then returns NA for the top plans. The
        ## max(1, ...) keeps at least one bin (all-zero or empty input).
        parity_group = cut(par,
                           breaks = seq(0, max(1, ceiling(max(par) * 100))) / 100,
                           include.lowest = TRUE, labels = FALSE)
    ) %>%
    group_by(pg = parity_group / 100) %>%
    summarize(n = n()) %>%
    ungroup() %>%
    mutate(prop = n / nrow(sols))

## Share of all plans within .01, .05 and .10 of parity
numsols_df <- vapply(c(.01, .05, .1),
                     function(thr) sum(sols$par <= thr) / nrow(sols),
                     numeric(1))

## Annotation text: one line per threshold, "< <threshold> Parity: <share>%"
text_out <- data.frame(text = paste0(
    "< ", format(c(.01, .05, .10), nsmall = 2), " Parity: ",
    format(round(numsols_df, 4) * 100, nsmall = 2), "%",
    collapse = "\n"
))

## Bar chart of the binned parity distribution. `pg - .005` moves each bar
## from its bin's upper edge to the bin's middle, and the threshold shares
## from text_out are printed at the top of the panel.
parity_plot <- ggplot(parity_plot_df, aes(x = pg - .005, y = prop)) +
    geom_col() +
    geom_text(aes(x = .1, y = Inf, label = text), data = text_out,
              hjust = 1.1, vjust = 1.2) +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    labs(title = "Distribution of Population Parity on 250-Precinct Map",
         x = "Distance from Population Parity",
         y = "Share of Sampled Plans") +
    theme_classic() +
    theme(
        panel.border = element_rect(colour = "black", fill = NA),
        plot.title = element_text(hjust = 0.5),
        legend.position = c(.05, .05)
    )

## Write the map and the parity chart side by side onto one PDF page
pdf(file = "../../paper/figs/map_descriptives_250precinct.pdf", width = 10.5, height = 5)
grid.arrange(map_plot, parity_plot, nrow = 1, widths = c(2, 2))
dev.off()
